package com.tca.common.custom.inject.plugins;

import com.google.common.collect.Lists;
import com.tca.common.core.utils.ValidateUtils;
import com.tca.common.core.utils.helper.Assert;
import com.tca.common.custom.inject.annotation.FactoryCode;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.env.Environment;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * Chooses, per deployment, which implementation of a customizable bean stays registered.
 * <p>
 * Bean classes annotated with {@link FactoryCode} are factory-specific implementations. The active
 * factory code is read from the JVM system property {@code factoryCode}, which is required. For every
 * bean definition whose class carries {@code @FactoryCode}:
 * <ul>
 *     <li>if the annotation value does not equal the active factory code (case-insensitive), that
 *     bean definition is removed;</li>
 *     <li>if it matches and its direct superclass (the "default implementation") is listed in
 *     {@code tca.custom.inject.impl[i]}, the factory-specific definition is removed and the default
 *     one is kept;</li>
 *     <li>otherwise, the default implementation's bean definitions are removed, so the
 *     factory-specific bean is the one that gets registered.</li>
 * </ul>
 * A matching class whose superclass is {@code Object} has no default implementation and is simply kept.
 *
 * @author zhoua
 * @date 2024/9/7 12:18
 */
@Slf4j
@ConditionalOnProperty(value = "tca.custom.inject.enabled", havingValue = "true", matchIfMissing = true)
public class FactoryCodeBeanDefinitionRegistryPostProcessor implements BeanDefinitionRegistryPostProcessor,
        ApplicationContextAware {

    private static final String FACTORY_CODE_PARAM = "factoryCode";

    /** Fully-qualified class names of default implementations that must win over factory-specific ones. */
    private final List<String> injectImplList = Lists.newArrayList();

    /** Class loader of the application context; {@code null} until {@link #setApplicationContext} runs. */
    private ClassLoader beanClassLoader;

    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry beanDefinitionRegistry) throws BeansException {
        String factoryCodeDetermined = System.getProperty(FACTORY_CODE_PARAM);
        Assert.notNull(factoryCodeDetermined, "system param factoryCode can not be null");

        // A Set: several factory impls may share one default impl, and removing
        // the same bean definition name twice would throw.
        Set<String> beanDefinitionNamesNeedRemoved = new LinkedHashSet<>();
        // Default impls displaced by a matching factory impl; resolved to bean names after the scan.
        Set<Class<?>> defaultImplsNeedRemoved = new LinkedHashSet<>();

        for (String beanDefinitionName : beanDefinitionRegistry.getBeanDefinitionNames()) {
            BeanDefinition beanDefinition = beanDefinitionRegistry.getBeanDefinition(beanDefinitionName);
            String beanClassName = beanDefinition.getBeanClassName();
            if (ValidateUtils.isEmpty(beanClassName)) {
                continue;
            }

            Class<?> clazz = resolveBeanClass(beanClassName);
            FactoryCode factoryCode = clazz.getAnnotation(FactoryCode.class);
            if (factoryCode == null) {
                continue;
            }

            if (!factoryCode.value().equalsIgnoreCase(factoryCodeDetermined)) {
                // not matched, need removed
                beanDefinitionNamesNeedRemoved.add(beanDefinitionName);
                continue;
            }

            // matched: decide between this impl and the default impl it extends
            Class<?> defaultImpl = getDefaultImplClass(clazz);
            if (defaultImpl == null) {
                // nothing to compete with, keep this bean
                continue;
            }
            if (injectImplList.contains(defaultImpl.getName())) {
                // configuration forces the default impl: drop the factory-specific one
                beanDefinitionNamesNeedRemoved.add(beanDefinitionName);
                continue;
            }
            defaultImplsNeedRemoved.add(defaultImpl);
        }

        for (Class<?> defaultImpl : defaultImplsNeedRemoved) {
            beanDefinitionNamesNeedRemoved.addAll(findBeanDefinitionNames(beanDefinitionRegistry, defaultImpl));
        }

        log.info("BeanDefinition need removed: {}", beanDefinitionNamesNeedRemoved);
        beanDefinitionNamesNeedRemoved.forEach(beanDefinitionRegistry::removeBeanDefinition);
    }

    /**
     * Loads a bean class without running its static initializers.
     * <p>
     * It uses the application context's class loader when one has been captured. Otherwise it uses this
     * class's own loader, which is the loader the plain {@code Class.forName(String)} call would use.
     *
     * @param beanClassName fully-qualified class name taken from a bean definition
     * @return the loaded class
     * @throws IllegalStateException if the class cannot be found
     */
    private Class<?> resolveBeanClass(String beanClassName) {
        ClassLoader classLoader = beanClassLoader != null ? beanClassLoader : getClass().getClassLoader();
        try {
            return Class.forName(beanClassName, false, classLoader);
        } catch (ClassNotFoundException e) {
            log.error("initial java class error, className = {}", beanClassName, e);
            throw new IllegalStateException("Cannot load bean class " + beanClassName, e);
        }
    }

    /**
     * Returns the default implementation that a factory-specific class overrides, i.e. its direct superclass.
     *
     * @param clazz class annotated with {@link FactoryCode}
     * @return the superclass, or {@code null} when there is none or it is {@code Object}
     */
    private Class<?> getDefaultImplClass(Class<?> clazz) {
        Class<?> superclass = clazz.getSuperclass();
        if (superclass == null || superclass == Object.class) {
            return null;
        }
        return superclass;
    }

    /**
     * Collects the names of registered bean definitions that belong to a default implementation.
     * <p>
     * These are the definitions whose bean class name equals {@code defaultImpl}'s name. It also adds the
     * definition registered under the conventional bean name derived from the class, which covers
     * definitions that do not record a bean class name. Only names that are actually registered are
     * returned, so removing them cannot fail.
     *
     * @param registry    registry to search
     * @param defaultImpl default implementation class
     * @return registered bean definition names for {@code defaultImpl}; possibly empty
     */
    private Set<String> findBeanDefinitionNames(BeanDefinitionRegistry registry, Class<?> defaultImpl) {
        Set<String> names = new LinkedHashSet<>();
        String conventionalName = getDefaultImplBeanDefinitionName(defaultImpl);
        if (registry.containsBeanDefinition(conventionalName)) {
            names.add(conventionalName);
        }
        String defaultImplName = defaultImpl.getName();
        for (String name : registry.getBeanDefinitionNames()) {
            if (defaultImplName.equals(registry.getBeanDefinition(name).getBeanClassName())) {
                names.add(name);
            }
        }
        return names;
    }

    /**
     * Derives the default bean name that Spring's annotation-based naming assigns to a class.
     * <p>
     * It takes the short class name, rendering nested classes as {@code Outer.Inner}, and lower-cases
     * the first character. The name is kept unchanged when its first two characters are both upper
     * case (e.g. {@code URLService}).
     *
     * @param defaultImpl default implementation class
     * @return the conventional bean definition name
     */
    private String getDefaultImplBeanDefinitionName(Class<?> defaultImpl) {
        String fullName = defaultImpl.getName();
        String shortName = fullName.substring(fullName.lastIndexOf('.') + 1).replace('$', '.');
        if (shortName.length() > 1 && Character.isUpperCase(shortName.charAt(0))
                && Character.isUpperCase(shortName.charAt(1))) {
            return shortName;
        }
        return Character.toLowerCase(shortName.charAt(0)) + shortName.substring(1);
    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory configurableListableBeanFactory) throws BeansException {
        // intentionally empty: all work happens in postProcessBeanDefinitionRegistry
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        // resolve bean classes with the same class loader the context uses
        this.beanClassLoader = applicationContext.getClassLoader();
        Environment environment = applicationContext.getEnvironment();
        initInjectImplList(environment);
    }

    /**
     * Reads the indexed properties {@code tca.custom.inject.impl[0]}, {@code [1]}, ... into
     * {@link #injectImplList}. Reading stops at the first index whose value is empty.
     *
     * @param environment environment to read the properties from
     */
    private void initInjectImplList(Environment environment) {
        log.info("start to init injectImplList...");

        for (int i = 0; i < Integer.MAX_VALUE; i++) {
            String impl = environment.getProperty(String.format("tca.custom.inject.impl[%d]", i));
            if (ValidateUtils.isEmpty(impl)) {
                break;
            }
            injectImplList.add(impl);
        }

        log.info("init injectImplList end, injectImplList: {}", injectImplList);
    }
}
